#!/usr/bin/env python3
'''
Author: Paschalis Agapitos
Project: Mestizajes

Provides 3D PCA visualization utilities for century- and person-level
embedding representations, projecting high-dimensional vectors into a
standardized three-component space, coloring points by historical period,
and optionally overlaying individual persons with nearest-century arrows
for intuitive exploration of temporal embedding structure.
'''

import warnings
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 – needed for 3‑D projection

# Apply Matplotlib's "ggplot" style sheet at import time.  This updates the
# global rcParams, so it affects every figure created after this module is
# imported, not only the ones produced by the functions below.
plt.style.use("ggplot")

def plot_centuries_and_persons(
    century_embeddings: dict,
    person_embeddings: dict,
    field_of_activity: str,
    persons_to_plot=None,
    get_period_fn=None,
    period_colors=None,
    with_mean: bool = True,
    with_std: bool = True,
    n_pca_components: int = 3,
    random_state: int = 42,
    annotate_points: bool = True,
    figsize: tuple = (12, 9),
    century_marker_size: int = 350,
    person_marker_size: int = 80,
    show_arrows: bool = True,
    arrow_linewidth: float = 1.5,
    arrow_alpha: float = 0.5,
    similarity_metric: str = "cosine",
    save_plot: bool = True,
    print_summary: bool = True,
    freq_or_embed: str = "embeddings"
):
    """Plot century embeddings together with selected persons in 3‑D PCA space.

    If *persons_to_plot* is **None**, only centuries are plotted. Otherwise, only the
    persons whose names appear in *person_embeddings* are projected and visualised.
    A console summary showing the three closest centuries to each person (according
    to *similarity_metric*) can also be printed.

    Parameters
    ----------
    century_embeddings, person_embeddings : dict[str, np.ndarray]
        Mapping from century / person name to embedding vector of *equal length*.
    field_of_activity : str
        Human activity domain (used only for the plot title and output filename).
    persons_to_plot : None | str | Sequence[str], optional
        Name or list of names to project. Ignored if not found in
        *person_embeddings*.
    get_period_fn : Callable[[int], str], optional
        Maps an integer century to a period label.  A reasonable default is used
        when *None*.
    period_colors : dict[str, str], optional
        Mapping from period label to HEX colour.  Missing labels fall back to
        the colour of "Other".
    with_mean, with_std : bool, optional
        Passed to :class:`sklearn.preprocessing.StandardScaler`.
    n_pca_components, random_state : int, optional
        Passed to :class:`sklearn.decomposition.PCA`.
    annotate_points : bool, optional
        Write century numbers / person names next to the markers.
    figsize : tuple[int, int], optional
        Size passed to :func:`matplotlib.pyplot.figure`.
    century_marker_size, person_marker_size : int, optional
        Marker areas in points².
    show_arrows : bool, optional
        Draw an arrow from each person to its nearest century.
    arrow_linewidth, arrow_alpha : float, optional
        Aesthetics of the quiver arrows.
    similarity_metric : {"cosine", "euclidean"}
        Criterion used to identify the nearest centuries.
    save_plot : bool, optional
        Persist the PNG under ``figures/embeddings_analysis`` when *True*.
    print_summary : bool, optional
        Write a textual summary to ``stdout``.

    Returns
    -------
    dict  with keys ``summary`` (list[dict]) and ``explained_variance`` (np.ndarray)
    """
    # 1.  Period labels and colours
    if get_period_fn is None:
        def default_get_period(century: int) -> str:
            if century in {-5, -4, -3, -2, -1, 1, 2, 3, 4}:
                return "Classical Era"
            elif 5 <= century <= 14:
                return "Middle Ages"
            elif century in {15, 16, 17}:
                return "Early Modern Era"
            elif century in {18, 19, 20}:
                return "Modern Era"
            return "Other"
        get_period_fn = default_get_period

    if period_colors is None:
        period_colors = {
            "Classical Era": "#E69F00",   # orange
            "Middle Ages": "#56B4E9",          # light blue
            "Early Modern Era": "#009E73",          # green
            "Modern Era": "#D55E00",           # blue
        }


    # 2.  Prepare the inputs
    centuries = sorted(century_embeddings)
    X_centuries = np.vstack([century_embeddings[c] for c in centuries])

    # persons_to_plot
    if persons_to_plot is None:
        person_names = []
    else:
        # Accept both a single string and any iterable of strings
        if isinstance(persons_to_plot, str):
            persons_to_plot = [persons_to_plot]
        missing = [p for p in persons_to_plot if p not in person_embeddings]
        if missing:
            warnings.warn(
                f"{len(missing)} name(s) not in person_embeddings and will be skipped: {missing}",
                stacklevel=2,
            )
        person_names = [p for p in persons_to_plot if p in person_embeddings]

    X_persons = (
        np.vstack([person_embeddings[p] for p in person_names])
        if person_names else
        np.empty((0, X_centuries.shape[1]))
    )

    # 3.  Standardisation and PCA
    scaler = StandardScaler(with_mean=with_mean, with_std=with_std)
    Xc_scaled = scaler.fit_transform(X_centuries)
    Xs_scaled = scaler.transform(X_persons) if person_names else X_persons

    pca = PCA(n_components=n_pca_components, random_state=random_state)
    Xc_pca = pca.fit_transform(Xc_scaled)[:, :3]
    Xs_pca = pca.transform(Xs_scaled)[:, :3] if person_names else Xs_scaled
    explained_variance = pca.explained_variance_ratio_[:3]

    # 4.  Helpers for nearest‑century search
    def _top3_cosine(vec):
        vec_n = vec / (np.linalg.norm(vec) + 1e-12)
        mat_n = X_centuries / (
            np.linalg.norm(X_centuries, axis=1, keepdims=True) + 1e-12
        )
        sims = mat_n @ vec_n
        idx = np.argsort(sims)[::-1][:3]
        return [(centuries[i], sims[i]) for i in idx]

    def _top3_euclidean(vec):
        dists = np.linalg.norm(X_centuries - vec, axis=1)
        idx = np.argsort(dists)[:3]
        return [(centuries[i], dists[i]) for i in idx]

    top3_fn = {"cosine": _top3_cosine, "euclidean": _top3_euclidean}.get(similarity_metric)
    if top3_fn is None:
        raise ValueError("similarity_metric must be 'cosine' or 'euclidean'")

    # 5.  Build summary and optionally print it
    summary = [
        {"person": n, f"top3_{similarity_metric}": top3_fn(person_embeddings[n])}
        for n in person_names
    ]

    if print_summary and summary:
        print(f"Nearest centuries ({similarity_metric}) – {field_of_activity}:")
        for row in summary:
            person = row["person"]
            items = row[f"top3_{similarity_metric}"]
            readable = ", ".join(f"{c} ({v:.3f})" for c, v in items)
            print(f"  • {person}: {readable}")

    # 6.  Visualisation
    sns.set_context("paper", font_scale=1.5)
    fig = plt.figure(figsize=figsize, dpi=500)
    ax = fig.add_subplot(111, projection="3d")

    # centuries 
    century_periods = [get_period_fn(c) for c in centuries]
    for period, colour in period_colors.items():
        mask = np.equal(century_periods, period)
        if not any(mask):
            continue
        pts = Xc_pca[mask]
        ax.scatter(
            pts[:, 0], pts[:, 1], pts[:, 2],
            s=century_marker_size, c=colour, alpha=0.7,
            edgecolors="black", linewidth=1, marker="o", label=period,
        )
        if annotate_points:
            for (x, y, z), cent in zip(pts, np.array(centuries)[mask]):
                ax.text(x, y, z, str(cent), 
                        ha="center", va="center")

    # persons
    for i, name in enumerate(person_names):
        nearest_cent, _ = top3_fn(person_embeddings[name])[0]
        period = get_period_fn(nearest_cent)
        colour = period_colors.get(period, period_colors["Other"])
        x, y, z = Xs_pca[i]
        ax.scatter(
            x, y, z,
            s=person_marker_size, c=colour, alpha=0.3,
            edgecolors="black", linewidth=0.5, marker="^",
        )
        if annotate_points:
            ax.text(x, y, z + 0.02, name, ha="center", va="bottom")
        if show_arrows:
            origin = Xc_pca[centuries.index(nearest_cent)]
            ax.quiver(
                origin[0], origin[1], origin[2],
                x - origin[0], y - origin[1], z - origin[2],
                length=1, normalize=False, arrow_length_ratio=0.1,
                linewidth=arrow_linewidth, color=colour, alpha=arrow_alpha,
            )

    # 7.  Axis labels, title, legend
    ax.set_xlabel(f"PC1 ({explained_variance[0] * 100:.1f}%)")
    ax.set_ylabel(f"PC2 ({explained_variance[1] * 100:.1f}%)")
    ax.set_zlabel(f"PC3 ({explained_variance[2] * 100:.1f}%)")

    title_clean = field_of_activity.replace("_", " ").title()
    if person_names:
        title = f"Historical Periods and {title_clean} in PCA Space - {freq_or_embed.title()}"
    else:
        title = f"Historical Periods in PCA Space \n {title_clean} - {freq_or_embed.title()}"
    ax.set_title(title)

    # ax.legend(title="Historical Periods", loc="upper left", bbox_to_anchor=(0.02, 0.98))
    # ax.text2D(
    #     1, 0.002,
    #     "Note: Negative centuries = BCE\n(e.g., −4 = 4th c. BCE)\nPositive centuries = CE\n(e.g., 4 = 4th c. CE)",
    #     transform=ax.transAxes,
    #     fontsize=8, ha="right", va="bottom",
    #     bbox=dict(facecolor="white", edgecolor="black", alpha=0.6),
    # )
    ax.grid(True, alpha=0.3, color="k")

    # plt.tight_layout()
    if save_plot:
        field_clean = (
            field_of_activity.replace(" ", "_").replace("&", "and").lower()
        )
        save_path = Path("../figures/embeddings_analysis/")
        save_path.mkdir(parents=True, exist_ok=True)
        plt.tight_layout()
        plt.savefig(
            save_path / f"century_{field_clean}_pca_{freq_or_embed[:4].lower()}.png",
            dpi=500,
            transparent=True,
        )
    plt.show()

    return {"summary": summary, "explained_variance": explained_variance}